//import java.util.HashMap;
//import java.util.HashSet;
//import java.util.List;
//import java.util.Map;
//import java.util.Set;
//import java.util.concurrent.ConcurrentHashMap;
//import java.util.function.Consumer;
//import org.apache.spark.sql.Row;
//
//public class BayesClassifier {
//
//	// 训练集总的词数
//	private Long totalWordCount = new Long(0);
//
//	// 类别
//	private Map<String, String> classMap = new HashMap<String, String>();
//
//	// 类别对应的文章数
//	private Map<String, Long> classArticleCount = new ConcurrentHashMap<String, Long>();
//
//	// 每个类别词的数量
//	private Map<String, Long> classWordCount = new ConcurrentHashMap<String, Long>();
//
//	// 每个类别对应的词典和词频
//	private Map<String, Map<String, Long>> classWordMap = new ConcurrentHashMap<String, Map<String, Long>>();
//
//	// 存放所有出现过的词
//	private Set<String> allWordSet = new HashSet<String>();
//
//	/**
//	 * 训练数据
//	 * @param records 每个Record存储文章的标题、内容和类别等信息
//	 */
//	public void train(List<Row> records) {
//
//		records.forEach(new Consumer<Row>() {
//			@Override
//			public void accept(Row record) {
//				// 文章的类别
//				String category = record.getString(4);
//				// 新的类别
//				if (!classMap.containsKey(category)) {
//					classMap.put(category, category);
//					classArticleCount.put(category, 0L);
//					classWordCount.put(category, 0L);
//					classWordMap.put(category, new HashMap<String, Long>());
//				}
//
//				Map<String, Long> wordMap = classWordMap.get(category);
//				// 获取切分的词
//				String[] words = record.getString(3).split(" ");
//				for (String word : words) {
//					// 更新该类别的词典和词频
//					if (wordMap.containsKey(word)) {
//						Long wordCount = wordMap.get(word);
//						wordMap.put(word, wordCount + 1);
//					} else {
//						wordMap.put(word, 1L);
//					}
//					allWordSet.add(word);
//				}
//				// 更新该类别的词典和词频
//				Long wordCount = classWordCount.get(category);
//				classWordCount.put(category, wordCount + words.length);
//
//				totalWordCount += words.length;
//			}
//		});
//	}
//
//	/**
//	 * @param classKey 类别
//	 * @param word 词
//	 * @return 类别中词出现的次数
//	 */
//	public Long wordInClassCount(String classKey, String word) {
//		Map<String, Long> wordMap = classWordMap.get(classKey);
//		Long wordCount = wordMap.get(word);
//		return (wordCount == null) ? 1L : wordCount;
//	}
//
//	/**
//	 * 选择分类概率最大的类别
//	 * @param probClassMap
//	 * @return 返回分类结果
//	 */
//	public String getMaxClassification(Map<String, Double> resultMap) {
//		Set<String> keySet = resultMap.keySet();
//		String maxClassification = null;
//		double maxProbability = Double.NEGATIVE_INFINITY;
//		// 选择归类概率最大的类别作为分类的结果
//		for (String classKey : keySet) {
//			double probability = resultMap.get(classKey);
//			if (probability > maxProbability) {
//				maxProbability = probability;
//				maxClassification = classKey;
//			}
//		}
//		return maxClassification;
//	}
//
//	/**
//	 * 对文章进行分类
//	 * @param record 待分类的文章
//	 * @return 分类的类别
//	 */
//	public String classify(Row record) {
//		// 获取文章的分词结果
//		String[] words = record.getString(3).split(" ");
//		Map<String, Double> resultMap = new HashMap<String, Double>();
//		Set<String> keySet = classMap.keySet();
//		// 计算文章属于每个类别的概率
//		for (String classKey : keySet) {
//			double probability = 0.0;
//			for (String word: words) {
//				double wordFrequency = wordInClassCount(classKey, word) * 1.0 / (classWordCount.get(classKey)+ + allWordSet.size());
//				probability += Math.log(wordFrequency);
//			}
//			probability += Math.log(classWordCount.get(classKey) * 1.0 / totalWordCount);
//			resultMap.put(classKey, probability);
//		}
//		// 选择分类结果并返回
//		return getMaxClassification(resultMap);
//	}
//}
